{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Use Flax modules as part of a BrainPy program\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/brainpy/blob/master/docs_version2/tutorial_advanced/integrate_flax_into_brainpy.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/brainpy/blob/master/docs_version2/tutorial_advanced/integrate_flax_into_brainpy.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-20T14:58:22.773685400Z",
     "start_time": "2023-05-20T14:58:20.859311700Z"
    },
    "collapsed": false,
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "import brainpy as bp\n",
    "import brainpy.math as bm\n",
    "import brainpy_datasets as bd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-20T14:58:22.789897100Z",
     "start_time": "2023-05-20T14:58:22.775687200Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "from flax import linen as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-20T14:58:22.806933500Z",
     "start_time": "2023-05-20T14:58:22.790896Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "bm.set(mode=bm.training_mode, dt=1.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-20T15:13:00.928515700Z",
     "start_time": "2023-05-20T15:13:00.912573Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.4.1'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bp.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "In this example, we use [Flax](https://github.com/google/flax), a library for deep neural networks, to define a convolutional neural network (CNN). Then, we integrate this CNN into an RNN model defined with BrainPy's syntax."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Here, we first use **Flax** to define a CNN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-20T14:58:22.820077600Z",
     "start_time": "2023-05-20T14:58:22.808986800Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "  \"\"\"A CNN model implemented by using Flax.\"\"\"\n",
    "\n",
    "  @nn.compact\n",
    "  def __call__(self, x):\n",
    "    x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n",
    "    x = nn.relu(x)\n",
    "    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
    "    x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n",
    "    x = nn.relu(x)\n",
    "    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
    "    x = x.reshape((x.shape[0], -1))  # flatten\n",
    "    x = nn.Dense(features=256)(x)\n",
    "    x = nn.relu(x)\n",
    "    return x"
   ]
  },
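  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "As a quick sanity check, the following sketch initializes the Flax CNN on its own and applies it to a dummy input shaped like a single time step, `[1, 4, 28, 1]`. It is only an illustration: the block repeats the CNN layers to stay self-contained, and the names `rng`, `dummy`, `variables`, and `out` are ours, not part of the tutorial. The result should have 256 features per sample, which is the input size the recurrent layer expects.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from flax import linen as nn\n",
    "\n",
    "\n",
    "class CNN(nn.Module):\n",
    "  # Same layers as the CNN in this tutorial, repeated to keep the sketch self-contained.\n",
    "\n",
    "  @nn.compact\n",
    "  def __call__(self, x):\n",
    "    x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n",
    "    x = nn.relu(x)\n",
    "    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
    "    x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n",
    "    x = nn.relu(x)\n",
    "    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
    "    x = x.reshape((x.shape[0], -1))  # flatten\n",
    "    x = nn.Dense(features=256)(x)\n",
    "    x = nn.relu(x)\n",
    "    return x\n",
    "\n",
    "\n",
    "rng = jax.random.PRNGKey(0)\n",
    "dummy = jnp.ones([1, 4, 28, 1])      # (batch, height, width, channels)\n",
    "variables = CNN().init(rng, dummy)   # build the parameters from the example input\n",
    "out = CNN().apply(variables, dummy)  # forward pass with those parameters\n",
    "print(out.shape)\n",
    "```\n",
    "\n",
    "The example input only fixes the parameter shapes; the batch dimension can differ when the model is applied later."
   ]
  },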
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Then, we define an RNN model using the BrainPy interface. Note that the Flax module is wrapped with ``bp.dnn.FromFlax`` and is called like any other module within a single time step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-20T14:58:22.838587100Z",
     "start_time": "2023-05-20T14:58:22.821079400Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class Network(bp.DynamicalSystemNS):\n",
    "  def __init__(self):\n",
    "    super(Network, self).__init__()\n",
    "    self.cnn = bp.dnn.FromFlax(\n",
    "      CNN(), # the model\n",
    "      bm.ones([1, 4, 28, 1])  # an example of the input used to initialize the model parameters\n",
    "    )\n",
    "    self.rnn = bp.dyn.GRUCell(256, 100)\n",
    "    self.linear = bp.dnn.Dense(100, 10)\n",
    "\n",
    "  def update(self, x):\n",
    "    x = self.cnn(x)\n",
    "    x = self.rnn(x)\n",
    "    x = self.linear(x)\n",
    "    return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "We initialize the network and the optimizer. The loss function and the BP trainer are defined in the cells that follow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-20T14:58:24.465237300Z",
     "start_time": "2023-05-20T14:58:22.836586800Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    }
   ],
   "source": [
    "net = Network()\n",
    "opt = bp.optim.Momentum(0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
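    "We load the MNIST dataset. Each 28x28 image is reshaped into a sequence of 7 time steps, where every step is a 4x28 patch with one channel, and the pixel values are divided by 255. The ``get_data`` generator shuffles the dataset and yields mini-batches."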
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-20T14:58:24.589939500Z",
     "start_time": "2023-05-20T14:58:24.466823Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "data = bd.vision.MNIST(r'D:\\data', download=True)\n",
    "data.data = data.data.reshape(-1, 7, 4, 28, 1) / 255\n",
    "\n",
    "\n",
    "def get_data(batch_size):\n",
    "  key = bm.random.split_key()\n",
    "  # reuse the same key so that images and labels receive the same permutation\n",
    "  data.data = bm.random.permutation(data.data, key=key)\n",
    "  data.targets = bm.random.permutation(data.targets, key=key)\n",
    "\n",
    "  for i in range(0, len(data), batch_size):\n",
    "    yield data.data[i: i + batch_size], data.targets[i: i + batch_size]"
   ]
  },
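  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "To make the reshaping concrete, here is a small NumPy sketch that uses random stand-in images instead of MNIST. The arrays `images` and `labels` and their sizes are made up for illustration. It suggests how one 28x28 image becomes 7 time steps of 4 rows each, and how applying one permutation to both arrays keeps images and labels aligned.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "images = rng.integers(0, 256, size=(10, 28, 28))  # stand-in for 10 raw images\n",
    "labels = np.arange(10)                            # stand-in labels, one per image\n",
    "\n",
    "# Split each image into 7 consecutive strips of 4 rows and add a channel axis.\n",
    "seq = images.reshape(-1, 7, 4, 28, 1) / 255\n",
    "print(seq.shape)\n",
    "\n",
    "# Time step t of an image holds rows 4*t to 4*t+3 of that image.\n",
    "t = 2\n",
    "same = np.array_equal(seq[0, t, :, :, 0], images[0, 4 * t: 4 * t + 4] / 255)\n",
    "print(same)\n",
    "\n",
    "# Shuffle both arrays with the same index permutation so pairs stay matched.\n",
    "perm = rng.permutation(len(seq))\n",
    "seq_shuffled, labels_shuffled = seq[perm], labels[perm]\n",
    "```\n",
    "\n",
    "Because the labels here are just `0..9`, `labels_shuffled` equals `perm`, which makes it easy to confirm that each shuffled sequence still sits next to its own label."
   ]
  },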
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-20T14:58:24.605264100Z",
     "start_time": "2023-05-20T14:58:24.589939500Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def loss_func(predictions, targets):\n",
    "  # reduce over axis 1, the time axis of the batch-first outputs\n",
    "  logits = bm.max(predictions, axis=1)\n",
    "  loss = bp.losses.cross_entropy_loss(logits, targets)\n",
    "  accuracy = bm.mean(bm.argmax(logits, -1) == targets)\n",
    "  return loss, {'accuracy': accuracy}"
   ]
  },
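  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "The loss function reduces the network outputs over axis 1 before classification. The toy NumPy sketch below illustrates that reduction and the accuracy computation, assuming the outputs are batch-first so that axis 1 is time. All numbers are invented (2 samples, 3 time steps, 2 classes), and the cross-entropy term is left out.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# predictions[b, t, c]: output for sample b at time step t and class c\n",
    "predictions = np.array([\n",
    "    [[0.1, 0.2], [0.9, 0.3], [0.4, 0.1]],  # sample 0\n",
    "    [[0.2, 0.5], [0.1, 0.8], [0.3, 0.2]],  # sample 1\n",
    "])\n",
    "targets = np.array([0, 0])\n",
    "\n",
    "logits = predictions.max(axis=1)          # largest value per class over time\n",
    "predicted = logits.argmax(axis=-1)        # most active class per sample\n",
    "accuracy = (predicted == targets).mean()  # fraction of correct samples\n",
    "print(logits, predicted, accuracy)\n",
    "```\n",
    "\n",
    "Here sample 0 is classified correctly and sample 1 is not, so the accuracy is 0.5."
   ]
  },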
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Finally, we train the model using the ``BPTT.fit()`` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-20T15:13:00.912573Z",
     "start_time": "2023-05-20T14:58:24.606320200Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train 0 epoch, use 104.2070 s, loss 1.0793957710266113, accuracy 0.616583526134491\n",
      "Train 1 epoch, use 85.4961 s, loss 0.4177210330963135, accuracy 0.8495622277259827\n",
      "Train 2 epoch, use 85.1781 s, loss 0.27014848589897156, accuracy 0.9093307256698608\n",
      "Train 3 epoch, use 85.4031 s, loss 0.23874548077583313, accuracy 0.9184618592262268\n",
      "Train 4 epoch, use 86.0905 s, loss 0.21281874179840088, accuracy 0.925542950630188\n",
      "Train 5 epoch, use 85.5581 s, loss 0.19409772753715515, accuracy 0.9322085380554199\n",
      "Train 6 epoch, use 85.9805 s, loss 0.18303607404232025, accuracy 0.9356383085250854\n",
      "Train 7 epoch, use 85.0740 s, loss 0.16687186062335968, accuracy 0.9404421448707581\n",
      "Train 8 epoch, use 85.7086 s, loss 0.1607382893562317, accuracy 0.9421210289001465\n",
      "Train 9 epoch, use 87.4538 s, loss 0.15550467371940613, accuracy 0.9443760514259338\n"
     ]
    }
   ],
   "source": [
    "trainer = bp.BPTT(net, loss_fun=loss_func, optimizer=opt, loss_has_aux=True)\n",
    "trainer.fit(partial(get_data, batch_size=256), num_epoch=10)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
